#!/usr/bin/env python
# coding: utf-8


#Import libraries
import pandas as pd
import numpy as np
import sys
import pysam
import re
import matplotlib.pyplot as plt
from collections import defaultdict
import logging
import os
import contextlib
import json

from contig_data import ContigData

# Capture the command-line arguments passed in by the shell script
args = sys.argv[1:]

# Fail early with a usage hint instead of an opaque IndexError further down.
if len(args) < 6:
    sys.exit(
        "Usage: SCRIPT SAMPLE_NAME OUTDIR FASTA_FILENAME CONTIG_FILE "
        "DISCARD_DELETIONS DESIRED_CONTIG"
    )


# Import Bash variables
sample_name = args[0]
outdir = args[1]
fasta_filename = args[2]
contig_file = args[3]
# Command-line values are always strings, and any non-empty string (including
# "False" or "0") is truthy. Convert the flag to a real boolean here: the usual
# "off" spellings mean False, every other value keeps meaning True.
discard_deletions = args[4].strip().lower() not in (
    "", "0", "false", "f", "no", "n", "none", "off"
)
desired_contig = args[5]  # contig whose counts are exported as a table and as JSON


# Per-sample file locations (all inside <outdir>/<sample_name>)
sample_dir = outdir + "/" + sample_name
output_filename = sample_dir + "/Experiment_Clean.txt"  # per-read output table
log_file = sample_dir + "/Experiment_log.log"
text_file_path = sample_dir + "/text.txt"  # tab-delimited input that gets analysed
plot_path = sample_dir
summary_path = sample_dir + "/Experiment_summary.txt"
json_summary = sample_dir + "/experiment_counts.json"
data_df_out = sample_dir + "/Check.txt"


#Define all functions


def analyze_fasta(fasta_filename, contig_ranges=None):
    """
    Find the positions of consecutive identical bases (homopolymer runs) in every sequence of a FASTA file.

    Sequences are upper-cased before comparison, so 'a' and 'A' belong to the same run.
    Positions are 1-based. Blank lines anywhere in the file are ignored.

    Parameters:
    fasta_filename (str): Path to the FASTA file.
    contig_ranges (dict, optional): Maps a sequence name to a tuple whose first two items are the
                                    inclusive start and end positions to keep. Extra items, such as
                                    a trailing contig length, are ignored. Sequences without an
                                    entry are not filtered. If None, no filtering is applied.

    Returns:
    dict: Sequence name -> ascending list of positions that are part of a run of two or more
          identical bases.

    Raises:
    ValueError: If a sequence identifier occurs twice, or sequence data appears before the first
                identifier line.
    """
    # Local import: itertools is not among this file's module-level imports.
    from itertools import groupby

    sequence_chunks = {}  # name -> list of upper-cased sequence lines (joined once below)

    with open(fasta_filename, 'r') as file:
        current_sequence = None
        for line in file:
            line = line.strip()
            if not line:
                continue  # blank lines carry no sequence data
            if line.startswith('>'):  # Identifier line
                # The identifier is everything after '>', including any description text.
                # NOTE(review): confirm contig names used elsewhere match this full header text.
                current_sequence = line[1:]
                if current_sequence in sequence_chunks:
                    raise ValueError(f"Duplicate sequence identifier found: {current_sequence}")
                sequence_chunks[current_sequence] = []
            elif current_sequence is not None:  # Sequence line
                sequence_chunks[current_sequence].append(line.upper())
            else:
                raise ValueError("File format error: Data encountered before sequence identifier.")

    consecutive_positions_dict = {}

    for sequence_name, chunks in sequence_chunks.items():
        sequence_positions = []
        run_start = 1  # 1-based position of the first base of the current run

        # groupby yields one group per maximal run of identical bases
        for _, run in groupby(''.join(chunks)):
            run_length = sum(1 for _ in run)
            if run_length > 1:
                sequence_positions.extend(range(run_start, run_start + run_length))
            run_start += run_length

        # Keep only positions inside the requested range for this sequence, if one was given
        if contig_ranges and sequence_name in contig_ranges:
            # Slice so (start, end) and (start, end, length) tuples are both accepted
            range_start, range_end = contig_ranges[sequence_name][:2]
            sequence_positions = [pos for pos in sequence_positions if range_start <= pos <= range_end]

        consecutive_positions_dict[sequence_name] = sequence_positions

    return consecutive_positions_dict


def analyze_mutations(positions, bases, valid_substitution_positions, valid_deletion_positions, homopolymer_positions, discard_deletions=False):
    """
    Sort the mutations of one read into categories.

    `positions` and `bases` are walked in lockstep. A `bases` entry of 'Deletion' or 'Insertion'
    names that event; any other value is treated as a substituted base.

    Parameters:
    positions (list): Positions at which the read differs from the reference.
    bases (list): For each position, 'Insertion', 'Deletion' or the substituted base.
    valid_substitution_positions (set): Positions at which substitutions are recorded.
    valid_deletion_positions (set): Positions at which deletions are recorded.
    homopolymer_positions (set): Positions inside homopolymer runs; deletions there are ambiguous.
    discard_deletions (bool): Optional; when truthy, the first deletion flags the read for discard
                              and the function returns immediately. Defaults to False.

    Returns:
    dict: Keys 'perfect_match' (bool), 'insertions', 'ambiguous_deletions', 'deletions' (lists of
          positions), 'substitutions' (list of (position, base) tuples) and 'discard' (bool).
          'perfect_match' is True when no category received an entry. Substitutions and deletions
          outside their valid positions are ignored and so do not prevent a perfect match.
    """
    inserted, ambiguous, deleted, substituted = [], [], [], []
    result = {
        'perfect_match': False,
        'insertions': inserted,
        'ambiguous_deletions': ambiguous,
        'deletions': deleted,
        'substitutions': substituted,
        'discard': False,
    }

    # Nothing reported for this read at all
    if not (positions or bases):
        result['perfect_match'] = True
        return result

    for pos, kind in zip(positions, bases):
        if kind == 'Insertion':
            inserted.append(pos)
        elif kind != 'Deletion':
            # Substitution: only recorded at positions declared valid
            if pos in valid_substitution_positions:
                substituted.append((pos, kind))
        elif discard_deletions:
            # Any deletion disqualifies the read; categories collected so far stay in the result
            result['discard'] = True
            return result
        elif pos in homopolymer_positions:
            # Homopolymer membership takes precedence over the valid-deletion check
            ambiguous.append(pos)
        elif pos in valid_deletion_positions:
            deleted.append(pos)

    # A read whose mutations were all filtered out counts as a perfect match
    result['perfect_match'] = not (inserted or ambiguous or deleted or substituted)
    return result


def extract_and_analyze_reads(text_file_path, contig_data, consecutive_positions, output_text=False, text_filename=None, discard_deletions=False):
    """
    Analyze every read of a tab-delimited text file and record the results in `contig_data`.

    This is a logging wrapper around `process_text_file`: errors raised while processing are
    logged and swallowed, so the caller always gets control back.

    Parameters:
    text_file_path (str): Path to the tab-delimited input file (first line is a header).
    contig_data (ContigData): Container that receives the per-contig counts.
    consecutive_positions (dict): Contig name -> homopolymer positions of that contig.
    output_text (bool): If True (and `text_filename` is given), also write a per-read table.
    text_filename (str): Path of the per-read table written when `output_text` is True.
    discard_deletions (bool): If truthy, reads containing a deletion are discarded. Applied whether
                              or not a per-read table is written.

    Returns:
    None
    """
    print("Starting analysis...")
    logging.info("Initiating text file processing.")
    logging.info(f"Starting analysis on {text_file_path}")

    try:
        if output_text and text_filename:
            # The output file is opened exactly once, inside process_text_file, which also
            # writes the header. Opening it here as well would truncate it twice and let the
            # late header flush overwrite the first bytes of the data.
            headers = ["Read_Name", "Contig", "Bases", "Positions", "Sequence", "Tabulated"]
            process_text_file(
                text_file_path,
                contig_data,
                consecutive_positions,
                out_file_path=text_filename,
                discard_deletions=discard_deletions,
                output_header=headers,
            )
        else:
            # Count only; no per-read table is written
            process_text_file(
                text_file_path,
                contig_data,
                consecutive_positions,
                discard_deletions=discard_deletions,
            )

    except FileNotFoundError:
        logging.error(f"File not found: {text_file_path}")
    except ValueError as e:
        logging.error(f"Data processing error in {text_file_path}: {e}")
    except Exception as e:
        logging.error(f"Unexpected error in {text_file_path}: {e}", exc_info=True)
    finally:
        logging.info(f"Completed analysis on {text_file_path}")


def process_text_file(text_file_path, contig_data, consecutive_positions, out_file_path=None, discard_deletions=False, output_header=None):
    """
    Read the mutation text file line by line, update `contig_data`, and optionally write a table.

    Input columns (tab-separated, after a header line that is skipped):
    0 read name, 1 contig, 2 comma-separated mutation labels/bases, 3 comma-separated positions,
    5 query sequence. Column 4 is not used. Blank lines are skipped.

    Parameters:
    text_file_path (str): Path to the tab-delimited input file.
    contig_data (ContigData): Container whose update_* methods receive the counts.
    consecutive_positions (dict): Contig name -> homopolymer positions of that contig.
    out_file_path (str, optional): If given, a per-read table is written to this path. Reads with
                                   insertions and discarded reads are left out of it.
    discard_deletions (bool): If truthy, reads containing a deletion are counted as discarded and
                              skipped.
    output_header (list, optional): Column names written as the first line of the output table.

    Returns:
    None
    """
    # ExitStack lets the output file be optional while still guaranteeing both files are closed.
    with contextlib.ExitStack() as stack:
        text_file = stack.enter_context(open(text_file_path, 'r'))
        out_file = stack.enter_context(open(out_file_path, 'w')) if out_file_path else None

        if out_file is not None and output_header:
            out_file.write('\t'.join(output_header) + '\n')

        text_file.readline()  # Read and discard the header line

        # Homopolymer positions as sets, built once per contig, for O(1) membership tests
        homopolymer_sets = {}

        for line in text_file:
            stripped = line.strip()
            if not stripped:
                continue  # tolerate blank lines, e.g. at the end of the file

            parts = stripped.split('\t')
            read_name, contig, bases_str, positions_str = parts[0], parts[1], parts[2], parts[3]
            query_sequence = parts[5]
            bases = bases_str.split(',')
            # An empty positions field means no positions were reported for this read
            positions = [int(p) for p in positions_str.split(',')] if positions_str else []

            contig_info = contig_data.get_data().get(contig, {})
            valid_substitution_positions = contig_info.get('substitutions', set())
            valid_deletion_positions = contig_info.get('deletions', set())

            homopolymer_positions = homopolymer_sets.get(contig)
            if homopolymer_positions is None:
                homopolymer_positions = set(consecutive_positions.get(contig, ()))
                homopolymer_sets[contig] = homopolymer_positions

            mutation_details = analyze_mutations(
                positions,
                bases,
                valid_substitution_positions,
                valid_deletion_positions,
                homopolymer_positions,
                discard_deletions,
            )

            # Discard the read when running in no-deletion mode
            if mutation_details['discard']:
                contig_data.update_discarded(contig)
                continue

            tabulated = "No"  # becomes "Yes" once the read contributes to WT/deletion/substitution counts

            if mutation_details['perfect_match']:
                contig_data.update_simple_count(contig, 'WT')
                tabulated = "Yes"

            # Reads with insertions are counted but kept out of the output table
            has_insertions = bool(mutation_details['insertions'])
            if has_insertions:
                contig_data.update_simple_count(contig, 'insertions')

            if mutation_details['ambiguous_deletions']:
                contig_data.update_simple_count(contig, 'ambiguous_deletions')

            for position in mutation_details['deletions']:
                contig_data.update_deletion(contig, position)
                tabulated = "Yes"

            for position, base in mutation_details['substitutions']:
                contig_data.update_substitution(contig, position, base)
                tabulated = "Yes"

            if out_file is not None and not has_insertions:
                read_data = [
                    read_name,
                    contig,
                    ','.join(bases),
                    ','.join(map(str, positions)),
                    query_sequence,
                    tabulated,
                ]
                out_file.write('\t'.join(read_data) + '\n')




def extract_contig_data(contig_data):
    """
    Flatten the nested ContigData dictionary into a list of flat records.

    Position-specific categories ('substitutions', 'deletions') yield one record per position with
    the keys 'Contig', 'Position', 'Mutation' and 'Count'. Contig-wide counters ('insertions',
    'ambiguous_deletions', 'WT', 'No_CS') yield a single record without a 'Position' key. Any other
    key found for a contig is skipped.

    :param contig_data: ContigData object; only its get_data() method is used.
    :return: List of dictionaries, in the iteration order of the underlying data.
    """
    per_position_kinds = ('substitutions', 'deletions')
    contig_wide_kinds = ('insertions', 'ambiguous_deletions', 'WT', 'No_CS')

    records = []
    for contig_name, info in contig_data.get_data().items():
        for kind, payload in info.items():
            if kind in per_position_kinds:
                # payload maps position -> count
                records.extend(
                    {'Contig': contig_name, 'Position': pos, 'Mutation': kind, 'Count': tally}
                    for pos, tally in payload.items()
                )
            elif kind in contig_wide_kinds:
                # payload is the count itself
                records.append({'Contig': contig_name, 'Mutation': kind, 'Count': payload})
    return records


def summarize_mutation_data(contig_data, output_file_path):
    """
    Write one comma-separated summary line per contig.

    Parameters:
    contig_data (ContigData): Container with the mutation counts; only get_data() is used.
    output_file_path (str): Destination file. It is overwritten and holds a header line followed
                            by one line per contig.

    Returns:
    None: The summary is written to `output_file_path`.
    """
    per_contig = contig_data.get_data()
    column_names = (
        'Contig', 'Total_Reads', 'Insertions', 'Deletions',
        'Substitutions', 'Ambiguous_Deletions', 'WT', 'No_CS',
    )

    with open(output_file_path, 'w') as summary:
        summary.write(','.join(column_names) + '\n')

        for name, counters in per_contig.items():
            # Total_Reads adds up every integer-valued entry except 'No_CS'. The per-position
            # dictionaries ('deletions', 'substitutions') are not integers and so are excluded.
            # NOTE(review): any other integer stored per contig is included too — confirm that
            # this matches the intended definition of a read total.
            read_total = 0
            for key, value in counters.items():
                if key != 'No_CS' and isinstance(value, int):
                    read_total += value

            # Missing counters default to 0; per-position dictionaries are collapsed to their sum
            cells = [
                read_total,
                counters.get('insertions', 0),
                sum(counters.get('deletions', {}).values()),
                sum(counters.get('substitutions', {}).values()),
                counters.get('ambiguous_deletions', 0),
                counters.get('WT', 0),
                counters.get('No_CS', 0),
            ]
            summary.write(','.join([name] + [str(cell) for cell in cells]) + '\n')

def safe_divide(row):
    """
    Return the row's 'Count' divided by the WT count of the row's contig.

    The WT count is looked up in the module-level `wt_counts` table, which is indexed by contig
    name and has a 'Count' column.

    Parameters:
    row: Mapping or pandas Series with the keys 'Contig' and 'Count'.

    Returns:
    float: Count / WT count, or 0.0 when the contig has no WT entry or its WT count is zero or NaN.
    """
    numerator = row['Count']

    try:
        denominator = wt_counts.loc[row['Contig'], 'Count']
    except KeyError:
        # No WT row for this contig: treated like a zero denominator
        return 0.0

    if denominator == 0 or np.isnan(denominator):
        return 0.0
    return np.divide(numerator, denominator)


def convert_df_to_json(filtered_df):
    """
    Build a JSON string of per-position raw and normalized counts.

    Only rows whose 'Mutation' is 'deletions' or 'substitutions' are used. Rows that share a
    mutation type and position are summed.

    Parameters:
    filtered_df (DataFrame): Needs the columns 'Mutation', 'Position', 'Count' and
                             'Normalized_Count'.

    Returns:
    str: JSON text (indent=4) of the form
         {"deletions": {position: {"normalized_count": ..., "raw_count": ...}}, "substitutions": {...}}.
    """
    by_kind = {'deletions': {}, 'substitutions': {}}

    for _, record in filtered_df.iterrows():
        kind = record['Mutation']
        if kind not in ('deletions', 'substitutions'):
            continue  # contig-wide counters have no position and are not exported here

        bucket = by_kind[kind]
        pos = int(record['Position'])  # Position may be stored as a float; use integer keys
        entry = bucket.get(pos)
        if entry is None:
            bucket[pos] = {
                'normalized_count': record['Normalized_Count'],
                'raw_count': record['Count'],
            }
        else:
            # Same position seen again: accumulate
            entry['normalized_count'] += record['Normalized_Count']
            entry['raw_count'] += record['Count']

    return json.dumps(by_kind, indent=4)

def convert_to_serializable(data):
    """
    Recursively replace numpy values with native Python values so `json` can serialize the result.

    Handled containers: dict (values, and numpy scalar keys), list, tuple and numpy arrays.
    Numpy scalars become the matching Python scalar via `.item()`; numpy arrays become (nested)
    lists. Every other value is returned unchanged.

    Args:
        data: Arbitrarily nested structure of dicts, lists, tuples, numpy values and plain objects.

    Returns:
        A new structure of the same shape containing only native Python values where numpy
        values were found.
    """
    if isinstance(data, dict):
        # json rejects numpy scalars as keys as well as values, so convert both
        return {
            (key.item() if isinstance(key, np.generic) else key): convert_to_serializable(value)
            for key, value in data.items()
        }
    if isinstance(data, list):
        return [convert_to_serializable(item) for item in data]
    if isinstance(data, tuple):
        return tuple(convert_to_serializable(item) for item in data)
    if isinstance(data, np.ndarray):
        # tolist() unboxes numeric arrays; recurse in case of object arrays holding numpy scalars
        return convert_to_serializable(data.tolist())
    if isinstance(data, np.generic):
        return data.item()
    return data

def save_results_to_json(data, filename):
    """
    Write `data` to `filename` as indented JSON.

    The data is first passed through `convert_to_serializable`, so numpy values it contains are
    turned into native Python values before serialization.

    Args:
        data (dict): The data to save.
        filename (str): Path of the JSON file to create or overwrite.
    """
    # Convert before opening, so a conversion failure leaves any existing file untouched
    payload = convert_to_serializable(data)
    with open(filename, 'w') as handle:
        json.dump(payload, handle, indent=4)


############# Execute Code #############

# Parse the contig range file into contig_ranges: name -> (start, end, contig_length)

contig_ranges = {}

# Each non-blank line is whitespace-separated: name, start, end, length
with open(contig_file, 'r') as ranges_handle:
    for raw_line in ranges_handle:
        fields = raw_line.split()  # split() without arguments also drops surrounding whitespace
        if not fields:
            continue  # blank line

        name = fields[0]
        length = int(fields[3])

        if discard_deletions:
            # When discard_deletions is truthy the whole contig (1..length) is the range and
            # the start/end columns are not read
            span = (1, length)
        else:
            span = (int(fields[1]), int(fields[2]))

        contig_ranges[name] = (*span, length)


## Locate homopolymer positions in every reference sequence.
## No contig_ranges argument is passed, so analyze_fasta applies no range filtering here.
consecutive_positions = analyze_fasta(fasta_filename)

# Initialize the ContigData storage from the parsed ranges and the homopolymer positions
contigs = ContigData()
contigs.initialize_contigs(contig_ranges, consecutive_positions)


# Set up logging: INFO and above go to the per-sample log file
logging.basicConfig(
    filename=log_file,  # the directory holding this file must already exist
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)


# getLogger() without a name returns the root logger configured above.
# `logger` is not referenced again in the visible part of this file.
logger = logging.getLogger()  # root logger


######### Run main functions
try:
    # Analyse all reads and write the per-read table to output_filename
    extract_and_analyze_reads(
        text_file_path=text_file_path, 
        contig_data=contigs, 
        consecutive_positions=consecutive_positions, 
        output_text=True, 
        text_filename=output_filename,
        discard_deletions=discard_deletions
    )


except Exception as e:
    # Log and carry on: the statements below still run with whatever counts were collected.
    logging.error(f"An error occurred: {e}", exc_info=True)
    # Specific exceptions could be handled or re-raised here if the script should stop instead.
    # You could handle specific exceptions or re-raise them if you want the notebook to stop execution.

# Flatten the ContigData counters into one record per contig/mutation (and position)
data_list = extract_contig_data(contigs)

# Convert the list of dictionaries into a Pandas DataFrame
data_df = pd.DataFrame(data_list)

# Make sure every column used below exists, even when no records (or no position-specific
# records) were produced; otherwise the column look-ups raise KeyError.
for required_column in ('Contig', 'Position', 'Mutation', 'Count'):
    if required_column not in data_df.columns:
        data_df[required_column] = np.nan

# Filter dataframe to contain only desired_contig
filtered_df = data_df[data_df['Contig'] == desired_contig].copy()

# Create the output directory before anything is written into it
os.makedirs(plot_path, exist_ok=True)

# Temporary check: dump the rows of the desired contig
filtered_df.to_csv(data_df_out, sep='\t', index=False)

# Mutation categories that are plotted per position
mutation_types = ['deletions', 'substitutions']

# WT counts per contig; safe_divide uses this table as the denominator for rates
wt_counts = data_df[data_df['Mutation'] == 'WT'][['Contig', 'Count']].set_index('Contig')

# WT count of the desired contig (0 when that contig has no WT record)
desired_wt_rows = filtered_df.loc[filtered_df['Mutation'] == 'WT', 'Count']
contig_wt_count = desired_wt_rows.iloc[0] if len(desired_wt_rows) else 0

# Normalize every count of the desired contig by its WT count. A missing or zero WT count would
# give inf/NaN (which is not valid JSON), so fall back to 0.0 as safe_divide does.
if pd.notna(contig_wt_count) and contig_wt_count != 0:
    filtered_df['Normalized_Count'] = filtered_df['Count'] / contig_wt_count
else:
    logging.warning(
        f"No non-zero WT count for contig {desired_contig}; normalized counts set to 0.0"
    )
    filtered_df['Normalized_Count'] = 0.0


def _save_rate_plots(counts, color, label, title_prefix, file_suffix):
    """
    Sum counts per (contig, position), convert them to rates and save one line plot per contig.

    Parameters:
    counts (DataFrame): Rows with the columns 'Contig', 'Position' and 'Count'.
    color (str): Matplotlib color of the line.
    label (str): Legend label of the line.
    title_prefix (str): Text placed before "Mutation Rate in <contig>" in the title.
    file_suffix (str): The plot is saved as <plot_path>/<contig>_<file_suffix>.png.

    Returns:
    DataFrame: The grouped table. It has a 'Rate' column unless it is empty.
    """
    grouped = counts.groupby(['Contig', 'Position']).sum(numeric_only=True).reset_index()

    # Nothing to plot. Skipping also avoids apply(..., axis=1) on an empty frame, which does
    # not return a column that can be assigned to 'Rate'.
    if grouped.empty:
        return grouped

    # Rate = count at the position / WT count of the contig
    grouped['Rate'] = grouped.apply(safe_divide, axis=1)

    for contig_name in grouped['Contig'].unique():
        per_contig = grouped[grouped['Contig'] == contig_name]

        plt.figure(figsize=(10, 6))
        plt.plot(per_contig['Position'], per_contig['Rate'], marker='o', linestyle='-', color=color, label=label)
        plt.title(f'{title_prefix} Mutation Rate in {contig_name}')
        plt.xlabel('Position')
        plt.ylabel('Rate')
        plt.grid(True)
        plt.legend()

        plt.savefig(os.path.join(plot_path, f"{contig_name}_{file_suffix}.png"))
        plt.close()  # Close the figure to free up memory

    return grouped


# One plot per contig for each individual mutation type
for mutation in mutation_types:
    grouped_mutations = _save_rate_plots(
        data_df[data_df['Mutation'] == mutation],
        color='b',
        label=f"{mutation.capitalize()} Rate",
        title_prefix=mutation.capitalize(),
        file_suffix=f"{mutation}_rate",
    )

# One plot per contig with all plotted mutation types summed together
all_mutations = data_df[data_df['Mutation'].isin(mutation_types)]
grouped_all_mutations = _save_rate_plots(
    all_mutations,
    color='r',
    label="Combined Rate",
    title_prefix='Combined',
    file_suffix="combined_mutation_rate",
)



# Write the per-contig summary table
summarize_mutation_data(contigs, summary_path)


# Serialize the raw and normalized per-position counts of the desired contig.
# The JSON text is built before the file is opened, so a failure leaves no truncated file.
json_output = convert_df_to_json(filtered_df)

with open(json_summary, 'w') as json_handle:
    json_handle.write(json_output)



# Use the function to get the normalized stats
#normalized_statistics = normalize_data_by_wt(data_df)

# Convert normalized stats for serialization
#normalized_statistics_serializable = convert_to_serializable(normalized_statistics)


# Save the summary statistics to a JSON file
#save_results_to_json(normalized_statistics_serializable, json_summary)



